a56e1c
@@ -61,6 +61,7 @@
 import org.apache.hadoop.hive.ql.plan.AggregationDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
 import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
+import org.apache.hadoop.hive.ql.plan.ExprNodeDescUtils;
 import org.apache.hadoop.hive.ql.plan.GroupByDesc;
 import org.apache.hadoop.hive.ql.plan.JoinDesc;
 import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
@@ -695,12 +696,14 @@
private static void pruneReduceSinkOperator(boolean[] retainFlags,
       ReduceSinkOperator reduce, ColumnPrunerProcCtx cppCtx) throws SemanticException {
     ReduceSinkDesc reduceConf = reduce.getConf();
     Map<String, ExprNodeDesc> oldMap = reduce.getColumnExprMap();
+    LOG.info("RS " + reduce.getIdentifier() + " oldColExprMap: " + oldMap);
     RowResolver oldRR = cppCtx.getOpToParseCtxMap().get(reduce).getRowResolver();
     ArrayList<ColumnInfo> signature = oldRR.getRowSchema().getSignature();
 
     List<String> valueColNames = reduceConf.getOutputValueColumnNames();
     ArrayList<String> newValueColNames = new ArrayList<String>();
 
+    List<ExprNodeDesc> keyExprs = reduceConf.getKeyCols();
     List<ExprNodeDesc> valueExprs = reduceConf.getValueCols();
     ArrayList<ExprNodeDesc> newValueExprs = new ArrayList<ExprNodeDesc>();
 
@@ -713,10 +716,16 @@
private static void pruneReduceSinkOperator(boolean[] retainFlags,
           outputCol = Utilities.ReduceField.VALUE.toString() + "." + outputCol;
           nm = oldRR.reverseLookup(outputCol);
         }
-        ColumnInfo colInfo = oldRR.getFieldMap(nm[0]).remove(nm[1]);
-        oldRR.getInvRslvMap().remove(colInfo.getInternalName());
-        oldMap.remove(outputCol);
-        signature.remove(colInfo);
+
+        // Only remove information of a column if it is not a key,
+        // i.e. this column is not appearing in keyExprs of the RS
+        if (ExprNodeDescUtils.indexOf(outputColExpr, keyExprs) == -1) {
+          ColumnInfo colInfo = oldRR.getFieldMap(nm[0]).remove(nm[1]);
+          oldRR.getInvRslvMap().remove(colInfo.getInternalName());
+          oldMap.remove(outputCol);
+          signature.remove(colInfo);
+        }
+
       } else {
         newValueColNames.add(outputCol);
         newValueExprs.add(outputColExpr);
@@ -729,6 +738,7 @@
private static void pruneReduceSinkOperator(boolean[] retainFlags,
         .getFieldSchemasFromColumnList(reduceConf.getValueCols(),
         newValueColNames, 0, ""));
     reduceConf.setValueSerializeInfo(newValueTable);
+    LOG.info("RS " + reduce.getIdentifier() + " newColExprMap: " + oldMap);
   }
 
   /**
